package grpc

import (
	"context"
	"gitee.com/llakcs/agile-go/middleware"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

/**
 *   gRPC interceptors (GRPC拦截器)
 */

// unaryServerInterceptor returns a gRPC unary server interceptor.
//
// When s.timeout > 0 it derives a context bounded by that timeout and passes
// it to the middleware chain and the handler. If that context is already done
// (because the incoming context was canceled or had expired), the request is
// rejected with codes.Canceled or codes.DeadlineExceeded before any handler
// runs. Middleware matched for info.FullMethod via s.middleware.Match is then
// chained around the handler.
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if s.timeout > 0 {
			// Assign to the outer ctx (not :=): otherwise the derived context
			// is discarded when this block ends and the timeout never reaches
			// the handler.
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, s.timeout)
			defer cancel()

			// Fail fast if the request is already dead before doing any work.
			switch ctx.Err() {
			case context.Canceled:
				return nil, status.Errorf(codes.Canceled, "request is canceled")
			case context.DeadlineExceeded:
				return nil, status.Errorf(codes.DeadlineExceeded, "deadline is exceeded")
			}
		}

		// Adapt the gRPC handler to the middleware handler signature.
		h := func(ctx context.Context, req interface{}) (interface{}, error) {
			return handler(ctx, req)
		}
		if next := s.middleware.Match(info.FullMethod); len(next) > 0 {
			h = middleware.Chain(next...)(h)
		}
		return h(ctx, req)
	}
}

// wrappedStream decorates a grpc.ServerStream so that Context returns a
// caller-supplied context instead of the stream's own. Every other
// ServerStream method is promoted from the embedded stream.
type wrappedStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Compile-time check that *wrappedStream implements grpc.ServerStream.
var _ grpc.ServerStream = (*wrappedStream)(nil)

// NewWrappedStream wraps stream so that its Context method reports ctx.
// All other calls are delegated to stream unchanged.
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
	ws := new(wrappedStream)
	ws.ServerStream = stream
	ws.ctx = ctx
	return ws
}

// Context returns the context given to NewWrappedStream, overriding the
// embedded stream's Context.
func (ws *wrappedStream) Context() context.Context {
	return ws.ctx
}

// streamServerInterceptor returns a gRPC stream server interceptor.
//
// When s.timeout > 0, the handler receives a wrapped ServerStream whose
// Context carries a deadline of s.timeout derived from the stream's own
// context. The derived context is canceled when the handler returns.
// When s.timeout is not positive, the original stream is passed through
// unchanged, matching unaryServerInterceptor. Previously,
// context.WithTimeout(ctx, 0) produced an already-expired context that
// failed every stream.
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if s.timeout <= 0 {
			return handler(srv, ss)
		}
		ctx, cancel := context.WithTimeout(ss.Context(), s.timeout)
		defer cancel()
		// Hand the handler a stream whose Context() reports the bounded ctx.
		return handler(srv, NewWrappedStream(ctx, ss))
	}
}
